# Supervised Fine-tuning Trainer

Supervised fine-tuning (SFT) is a crucial step in RLHF. TRL provides an easy-to-use API to create SFT models and train them on your dataset with a few lines of code.

## Quickstart

If you have a dataset hosted on the 🤗 Hub, you can easily fine-tune your SFT model using [`SFTTrainer`] from TRL. Let us assume your dataset is `imdb`, the text you want to predict is inside the `text` field of the dataset, and you want to fine-tune the `facebook/opt-350m` model.
The following code snippet takes care of all the data pre-processing and training for you:

```python
from datasets import load_dataset
from trl import SFTTrainer

dataset = load_dataset("imdb", split="train")

trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=512,
)
trainer.train()
```
Make sure to pass a suitable value for `max_seq_length`, as the default value is `min(tokenizer.model_max_length, 1024)`.
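
The default described above can be sketched in plain Python. This is an illustration only: `default_max_seq_length` is a hypothetical helper, not a TRL function, and the very large value is an assumed placeholder for a tokenizer that reports no practical length limit.

```python
def default_max_seq_length(model_max_length: int, cap: int = 1024) -> int:
    """Mirror the documented default: min(tokenizer.model_max_length, 1024)."""
    return min(model_max_length, cap)


# A tokenizer limited to 512 tokens keeps its own limit.
print(default_max_seq_length(512))
# A tokenizer with a 2048-token limit is capped at 1024.
print(default_max_seq_length(2048))
# A tokenizer reporting an effectively unbounded length is also capped at 1024.
print(default_max_seq_length(int(1e30)))
```

If your examples are longer than the resulting value, pass `max_seq_length` explicitly.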

You can also construct a model outside of the trainer and pass it as follows:

```python
from transformers import AutoModelForCausalLM
from datasets import load_dataset
from trl import SFTTrainer

dataset = load_dataset("imdb", split="train")

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")

trainer = SFTTrainer(
    model,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=512,
)

trainer.train()
```

The above snippets use the default training arguments from the [`transformers.TrainingArguments`](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments) class. To change them, create your own `TrainingArguments` object and pass it to the [`SFTTrainer`] constructor, as is done in the [`supervised_finetuning.py` script](https://github.com/lvwerra/trl/blob/main/examples/stack_llama/scripts/supervised_finetuning.py) of the stack-llama example.
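
As a minimal sketch of that pattern, the snippet below builds a `TrainingArguments` object and hands it to the trainer through the `args` parameter. The hyperparameter values and the `output_dir` path are placeholders chosen for illustration, not recommendations, and `build_trainer` is a hypothetical helper.

```python
# Placeholder hyperparameters: adjust them to your hardware and dataset.
TRAINING_KWARGS = {
    "output_dir": "./sft-opt-350m",
    "per_device_train_batch_size": 4,
    "gradient_accumulation_steps": 4,
    "learning_rate": 2e-5,
    "num_train_epochs": 1,
    "logging_steps": 10,
}


def build_trainer(train_dataset):
    """Create an SFTTrainer that uses custom training arguments."""
    # Imported here so that defining the helper does not load the libraries.
    from transformers import TrainingArguments
    from trl import SFTTrainer

    training_args = TrainingArguments(**TRAINING_KWARGS)
    return SFTTrainer(
        "facebook/opt-350m",
        args=training_args,
        train_dataset=train_dataset,
        dataset_text_field="text",
        max_seq_length=512,
    )
```

Calling `build_trainer(dataset).train()` with the `imdb` dataset loaded as in the Quickstart would then train with these settings instead of the defaults.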

## Advanced usage

### Format your input prompts

For instruction fine-tuning, it is quite common to have two columns in the dataset: one for the prompt and one for the response.
This makes it possible to format examples the way [Stanford-Alpaca](https://github.com/tatsu-lab/stanford_alpaca) did:
```bash
Below is an instruction ...

### Instruction:
{prompt}

### Response:
{completion}
```
Let us assume your dataset has two fields, `question` and `answer`. You can then define a formatting function and pass it to the trainer:
```python
...
def formatting_prompts_func(example):
    text = f"### Question: {example['question']}\n ### Answer: {example['answer']}"
    return text

trainer = SFTTrainer(
    model,
    train_dataset=dataset,
    formatting_func=formatting_prompts_func,
)

trainer.train()
```
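
To check what the trainer will see, you can call the formatting function on a single example before training. The sketch below repeats the function from above and applies it to a made-up row; the `sample` dictionary is hypothetical and only stands in for one record of your dataset.

```python
def formatting_prompts_func(example):
    # Same template as above: question and answer joined into one string.
    text = f"### Question: {example['question']}\n ### Answer: {example['answer']}"
    return text


# Hypothetical dataset row, used only for illustration.
sample = {"question": "What is the capital of France?", "answer": "Paris"}
formatted = formatting_prompts_func(sample)
print(formatted)
```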

### Packing dataset ([`ConstantLengthDataset`])

[`SFTTrainer`] supports _example packing_, where multiple short examples are packed in the same input sequence to increase training efficiency. This is done with the [`ConstantLengthDataset`] utility class that returns constant length chunks of tokens from a stream of examples. To enable the usage of this dataset class, simply pass `packing=True` to the [`SFTTrainer`] constructor.

```python
...

trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset,
    dataset_text_field="text",
    packing=True
)

trainer.train()
```
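
The idea behind packing can be illustrated with a small, self-contained sketch. It is a simplified stand-in and not the actual [`ConstantLengthDataset`] implementation: `pack_examples` is a hypothetical helper, the token ids are toy values, and dropping the leftover tokens at the end is an assumption made to keep the example short.

```python
from typing import Iterable, Iterator, List


def pack_examples(
    tokenized_examples: Iterable[List[int]],
    seq_length: int,
    eos_token_id: int,
) -> Iterator[List[int]]:
    """Yield fixed-length chunks from a stream of tokenized examples."""
    buffer: List[int] = []
    for token_ids in tokenized_examples:
        # Append the example followed by a separator token.
        buffer.extend(token_ids)
        buffer.append(eos_token_id)
        # Emit as many full chunks as the buffer currently holds.
        while len(buffer) >= seq_length:
            yield buffer[:seq_length]
            buffer = buffer[seq_length:]
    # Tokens left in the buffer (fewer than seq_length) are dropped.


# Toy token ids standing in for three short tokenized examples.
examples = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
chunks = list(pack_examples(examples, seq_length=4, eos_token_id=0))
print(chunks)
```

Every chunk has exactly `seq_length` tokens, and a chunk may contain the end of one example and the start of the next, so no padding is needed.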

#### Customize your prompts using packed dataset

If your dataset has several fields that you want to combine, for example `question` and `answer`, you can pass a formatting function to the trainer that takes care of that. For example:

```python
def formatting_func(example):
    text = f"### Question: {example['question']}\n ### Answer: {example['answer']}"
    return text

trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset,
    packing=True,
    formatting_func=formatting_func
)

trainer.train()
```
You can also customize the [`ConstantLengthDataset`] further by passing its arguments directly to the [`SFTTrainer`] constructor. Refer to that class's signature for more information.

### Control over the pretrained model

You can pass the keyword arguments of the `from_pretrained()` method directly to the [`SFTTrainer`]. For example, to load a model in a different precision, analogous to

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.bfloat16)
```

pass the same argument to the trainer:

```python
...

trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset,
    dataset_text_field="text",
    torch_dtype=torch.bfloat16,
)

trainer.train()
```
Note that all keyword arguments of `from_pretrained()` are supported. 
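
To see why the precision matters, the rough calculation below compares the memory needed for the weights alone in `float32` and `bfloat16`. It is a back-of-the-envelope sketch: `weight_memory_mib` is a hypothetical helper and the parameter count is an assumed round figure, so the real checkpoint may differ.

```python
# Storage size of one parameter in each precision.
BYTES_PER_PARAM = {"float32": 4, "bfloat16": 2}


def weight_memory_mib(num_params: int, dtype: str) -> float:
    """Approximate memory taken by the model weights alone, in MiB."""
    return num_params * BYTES_PER_PARAM[dtype] / 2**20


# Assumed round figure; the exact parameter count of the checkpoint may differ.
num_params = 350_000_000

fp32 = weight_memory_mib(num_params, "float32")
bf16 = weight_memory_mib(num_params, "bfloat16")
print(f"float32:  {fp32:.0f} MiB")
print(f"bfloat16: {bf16:.0f} MiB")
```

The estimate covers the weights only; gradients, optimizer state, and activations add to the total during training.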

### Training adapters

We also support a tight integration with the 🤗 PEFT library, so that you can conveniently train adapters and share them on the Hub instead of training the entire model.

```python
from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig

dataset = load_dataset("imdb", split="train")

peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

trainer = SFTTrainer(
    "EleutherAI/gpt-neo-125m",
    train_dataset=dataset,
    dataset_text_field="text",
    peft_config=peft_config
)

trainer.train()
```
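
To get a feel for why adapters are cheap to train and share, the following back-of-the-envelope sketch counts the parameters LoRA adds to a single linear layer. LoRA learns two low-rank matrices of shapes `(r, d_in)` and `(d_out, r)` next to the frozen weight. The 768×768 layer size is an assumption for illustration, and `lora_param_count` is a hypothetical helper, not part of PEFT.

```python
def lora_param_count(d_in: int, d_out: int, r: int) -> int:
    """Trainable parameters LoRA adds to one (d_out x d_in) linear layer."""
    # Matrix A has shape (r, d_in) and matrix B has shape (d_out, r).
    return r * d_in + d_out * r


d_in = d_out = 768          # assumed layer size, for illustration only
r, lora_alpha = 16, 32      # values from the LoraConfig above

full = d_in * d_out
added = lora_param_count(d_in, d_out, r)
scaling = lora_alpha / r    # factor applied to the low-rank update

print(f"frozen weight: {full} parameters")
print(f"LoRA adapter:  {added} parameters ({added / full:.1%} of the layer)")
print(f"update scaling: {scaling}")
```

Only the adapter parameters are trained and saved, which is what keeps the shared artifact small.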

Note that when training adapters, we add a saving callback so that only the adapters are saved:
```python
import os

from transformers import TrainerCallback


class PeftSavingCallback(TrainerCallback):
    def on_save(self, args, state, control, **kwargs):
        checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
        kwargs["model"].save_pretrained(checkpoint_path)

        # Remove the full model weights so that only the adapters remain.
        if "pytorch_model.bin" in os.listdir(checkpoint_path):
            os.remove(os.path.join(checkpoint_path, "pytorch_model.bin"))
```
If you pass your own callbacks, include this one as well so that only the adapters are saved during training.
```python
...

callbacks = [YourCustomCallback(), PeftSavingCallback()]

trainer = SFTTrainer(
    "EleutherAI/gpt-neo-125m",
    train_dataset=dataset,
    dataset_text_field="text",
    torch_dtype=torch.bfloat16,
    peft_config=peft_config,
    callbacks=callbacks
)

trainer.train()
```

### Training adapters with 8-bit base models

To do so, first load your 8-bit model outside the trainer, then pass a `PeftConfig` to the trainer. For example:

```python
...

peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/gpt-neo-125m",
    load_in_8bit=True,
    device_map="auto",
)

trainer = SFTTrainer(
    model,
    train_dataset=dataset,
    dataset_text_field="text",
    peft_config=peft_config,
)

trainer.train()
```

## Best practices

Pay attention to the following best practices when training a model with this trainer:

- [`SFTTrainer`] pads sequences to the `max_seq_length` argument by default. If no value is passed, the trainer retrieves it from the tokenizer. Some tokenizers do not provide a default value, so the trainer takes the minimum of that value and 1024. Make sure to check it before training.
- When training adapters in 8-bit, you might need to tweak the arguments of the `prepare_model_for_int8_training` method from PEFT. We therefore advise using the `prepare_in_int8_kwargs` field, or creating the `PeftModel` outside the [`SFTTrainer`] and passing it.
- For more memory-efficient adapter training, you can load the base model in 8-bit: either add the `load_in_8bit` argument when creating the [`SFTTrainer`], or create the 8-bit base model outside the trainer and pass it.
- If you create a model outside the trainer, do not pass the trainer any additional keyword arguments that are related to the `from_pretrained()` method.

## SFTTrainer

[[autodoc]] SFTTrainer

## ConstantLengthDataset

[[autodoc]] trainer.ConstantLengthDataset